import paddle.nn.functional as F
import parl
from paddle import nn


class DimOneModel(parl.Model):
    def __init__(self, act_dim, obs_dim):
        """
        :param act_dim:  动作空间维度
        :param obs_dim:  状态空间维度
        """
        super(DimOneModel, self).__init__()
        # 三个全连接层, 将obs_dim维度的输入转换为act_dim维度的输出, 作为Q值
        self.fc1 = nn.Linear(obs_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, act_dim)

    def forward(self, obs):
        """
        :param obs:
        :return: Q
        """
        # 将obs转换为Q值
        h1 = F.relu(self.fc1(obs))
        h2 = F.relu(self.fc2(h1))
        Q = self.fc3(h2)
        return Q
